# Copyright (c) 2025 Apple Inc. Licensed under MIT License.

import os
from typing import Unpack

import streamlit.components.v1 as components

from .options import EmbeddingAtlasOptions, make_embedding_atlas_props

# Resolve the directory that contains this module, then the frontend build
# directory located beneath it.
parent_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(parent_dir, "widget_static/streamlit")

# Register the frontend assets in ``build_dir`` as a Streamlit component; the
# returned callable is what ``embedding_atlas`` below invokes.
_embedding_atlas = components.declare_component(
    name="embedding_atlas",
    path=build_dir,
)


def embedding_atlas(
    data_frame,
    *,
    key=None,
    **options: Unpack[EmbeddingAtlasOptions],
) -> dict:
    """
    Render an Embedding Atlas widget inside a Streamlit app.

    Every keyword option other than ``key`` is forwarded to
    ``make_embedding_atlas_props``; the resulting props are passed to the
    Streamlit component together with ``data_frame``.

    Args:
        data_frame:
            The data frame to visualize.

        x:
            Name of the column used as the X axis of the embedding.

        y:
            Name of the column used as the Y axis of the embedding.

        text:
            Name of the column holding the textual data.

        neighbors:
            Name of the column holding precomputed K-nearest neighbors for each point.
            Every value in that column should be a dictionary shaped like
            ``{ "ids": [id1, id2, ...], "distances": [distance1, distance2, ...] }``.

            - ``"ids"`` is an array of the neighbors' row ids, sorted by distance
              (when ``row_id`` is specified the ids must match the values in row_id,
              otherwise the row index is used).
            - ``"distances"`` holds the distance to each corresponding neighbor.

        labels:
            Labels for the embedding view. Pass the string ``"automatic"`` to generate
            labels automatically, or ``"disabled"`` to turn auto labels off.
            Automatic labels are produced by clustering the 2D density distribution and
            picking representative keywords using TF-IDF ranking.
            A list of labels may be passed instead. Each label must provide ``x`` and ``y``
            coordinates and ``text`` for the label content. Optionally, an integer ``level``
            roughly controls the zoom level at which the label appears, and ``priority``
            sets the label's priority. Higher priority labels have a better chance to
            appear when multiple labels overlap.

        stop_words:
            Stop words for automatic label generation.

        point_size:
            Override for the default point size of the embedding view.

        show_table:
            Whether the data table is displayed when the widget opens.

        show_charts:
            Whether charts are displayed when the widget opens.

        show_embedding:
            Whether the embedding view is displayed when the widget opens.

        key:
            The key of the Streamlit widget.

    Returns:
        A ``dict`` with the following key:

        - predicate: the SQL predicate for the current selection in the widget.

        An empty ``dict`` is supplied as the component's ``default`` value, so the
        ``predicate`` key may be absent before the widget has reported a selection.
    """

    return _embedding_atlas(
        data_frame=data_frame,
        props=make_embedding_atlas_props(**options),
        key=key,
        default={},
    )
